import math
import pdb

import numpy as np

def message_passing(X, A, p, q, d):
    """Predict labels with one round of truncated message passing.

    Each node's own evidence is ``L = 2 * d * X``. The message it sends,
    ``G``, equals ``L`` except where ``|X|`` exceeds ``log(p / q) / (2 * d)``.
    There the message saturates at ``+/- log(p / q)``. Predictions are the
    sign of the node's own evidence plus the aggregated neighbour messages.

    Args:
        X: Per-node feature array (presumably 1-D, one entry per node --
            confirm with callers).
        A: Adjacency matrix (dense ndarray or scipy sparse matrix/array),
            assumed shape (n, n).
        p, q: Positive numbers; ``log(p / q)`` is the message cap.
        d: Scale applied to ``X``.

    Returns:
        ``(G, Y)``: the truncated messages and ``sign(L + A @ G)``.
    """
    base = math.log(p / q)
    thrld = base / (2 * d)
    L = 2 * d * X
    G = np.copy(L)
    G[X > thrld] = base
    G[X < -thrld] = -base
    # `@` is matrix multiplication for dense arrays, scipy sparse matrices
    # and scipy sparse arrays alike. `*` is elementwise for dense arrays and
    # sparse arrays, which would broadcast to an (n, n) result.
    Y = np.sign(L + A @ G)
    return G, Y


def linear_convolution(X, A, d):
    """Predict labels with a single linear graph convolution.

    Args:
        X: Per-node feature array (presumably 1-D -- confirm with callers).
        A: Adjacency matrix (dense ndarray or scipy sparse matrix/array),
            assumed shape (n, n).
        d: Scale applied to ``X``.

    Returns:
        ``(L, Y)`` where ``L = 2 * d * X`` and ``Y = sign(L + A @ L)``.
    """
    L = 2 * d * X
    # `@` (not `*`) so that dense arrays and scipy sparse arrays are
    # multiplied as matrices rather than elementwise.
    Y = np.sign(L + A @ L)
    return L, Y

def optimal_nonlinear_propagation_general(X, A, p, q):
    """Predict labels by propagating features truncated at ``log(p / q)``.

    Unlike ``message_passing``, ``X`` is used unscaled. Each message ``G``
    equals ``X``, set to ``thrld`` where ``X > thrld`` and to ``-thrld``
    where ``X < -thrld``, with ``thrld = log(p / q)``.

    Args:
        X: Per-node values (presumably 1-D -- confirm with callers).
        A: Adjacency matrix (dense ndarray or scipy sparse matrix/array),
            assumed shape (n, n).
        p, q: Positive numbers; ``log(p / q)`` is the truncation level.

    Returns:
        ``(G, Y)``: the truncated messages and ``sign(X + A @ G)``.
    """
    thrld = math.log(p / q)
    # Boolean masks rather than np.clip: when p < q, thrld is negative and
    # np.clip would resolve the inverted bounds differently.
    G = np.copy(X)
    G[X > thrld] = thrld
    G[X < -thrld] = -thrld
    # `@` multiplies as a matrix for dense and scipy sparse inputs alike.
    Y = np.sign(X + A @ G)
    return G, Y

def psi_Laplace(X, mu, b, d):
    """Clip entries of ``X`` to ``[-2 * mu / b, 2 * mu / b]`` and sum rows.

    Args:
        X: Feature array; presumably shape (n, d) -- confirm with callers.
        mu, b: Parameters fixing the threshold ``2 * mu / b`` (presumably
            Laplace location/scale -- verify).
        d: Length of the last axis of ``X``.

    Returns:
        ``clipped @ np.ones(d)``, i.e. the row sums of the clipped array.

    ``X`` itself is left unmodified.
    """
    thrld = 2 * mu / b
    # Clip a copy; assigning through the masks on X directly would
    # overwrite the caller's array.
    clipped = X.copy()
    clipped[X > thrld] = thrld
    clipped[X < -thrld] = -thrld
    return clipped @ np.ones(d)

def message_passing_heter(X, A, p, q, d):
    """Predict labels with truncated message passing, heterophilous variant.

    Same as ``message_passing`` but with ``base = log(q / p)`` and negated
    messages. ``G = -L`` is used, except that it is set to ``-base`` where
    ``X > thrld`` and to ``base`` where ``X < -thrld``, with
    ``thrld = base / (2 * d)``.

    Args:
        X: Per-node feature array (presumably 1-D -- confirm with callers).
        A: Adjacency matrix (dense ndarray or scipy sparse matrix/array),
            assumed shape (n, n).
        p, q: Positive numbers; ``log(q / p)`` is the message cap.
        d: Scale applied to ``X``.

    Returns:
        ``(G, Y)``: the negated truncated messages and ``sign(L + A @ G)``.
    """
    base = math.log(q / p)
    thrld = base / (2 * d)
    L = 2 * d * X
    # Unary minus already allocates a new array, so no explicit copy needed.
    G = -L
    G[X > thrld] = -base
    G[X < -thrld] = base
    # `@` multiplies as a matrix for dense and scipy sparse inputs alike.
    Y = np.sign(L + A @ G)
    return G, Y

def linear_convolution_heter(X, A, d):
    """Predict labels with a single linear convolution, heterophilous variant.

    Args:
        X: Per-node feature array (presumably 1-D -- confirm with callers).
        A: Adjacency matrix (dense ndarray or scipy sparse matrix/array),
            assumed shape (n, n).
        d: Scale applied to ``X``.

    Returns:
        ``(L, Y)`` where ``L = 2 * d * X`` and ``Y = sign(L - A @ L)``.
    """
    L = 2 * d * X
    # Neighbour evidence is subtracted; `@` keeps this a matrix product
    # for dense arrays and scipy sparse arrays.
    Y = np.sign(L - A @ L)
    return L, Y